from keras import Input
from keras.layers import Dropout

from .base_layer import BaseLayer
from config import Config
from entity.model_describe import ModelDescribe


class DropoutLayer(BaseLayer):
    """Handler for ``Dropout`` entries in a model description list.

    ``persistence`` records the dropout rate on the ``ModelDescribe`` returned
    by the base class. ``transfer`` applies a keras ``Dropout`` layer with the
    same rate to ``inputs``.
    """

    @staticmethod
    def _resolve_rate(model_list_line: list):
        """Return the dropout rate given by *model_list_line*.

        If the line has fewer than two items (``Dropout`` was written without
        a rate), ``Config.DROP_OUT`` is used as the default. Otherwise the
        second item is returned as-is.

        Both ``persistence`` and ``transfer`` call this helper, so the stored
        rate and the rate of the built layer cannot drift apart.
        """
        # NOTE(review): item 1 is not converted. If lines are parsed from text
        # it may be a str, which keras ``Dropout`` would reject. Confirm the
        # caller supplies a number.
        if len(model_list_line) < 2:
            return Config.DROP_OUT
        return model_list_line[1]

    def persistence(self, model_list_line: list, train_id: int) -> ModelDescribe:
        """Build the ``ModelDescribe`` for this layer and record its rate.

        :param model_list_line: the layer's description items; index 1, if
            present, is the dropout rate.
        :param train_id: forwarded unchanged to ``BaseLayer.persistence``.
        :return: the base-class ``ModelDescribe`` with ``var1`` set to the
            resolved dropout rate.
        """
        m = super().persistence(model_list_line, train_id)
        m.var1 = self._resolve_rate(model_list_line)
        return m

    def transfer(self, model_list_line: list, inputs: Input, models: list = None):
        """Apply a keras ``Dropout`` layer to *inputs*.

        :param model_list_line: the layer's description items; index 1, if
            present, is the dropout rate, otherwise ``Config.DROP_OUT`` is used.
        :param inputs: the tensor the dropout layer is called on.
        :param models: not used by this layer; kept for interface
            compatibility with other layer handlers.
        :return: the output of ``Dropout(rate)(inputs)``.
        """
        return Dropout(self._resolve_rate(model_list_line))(inputs)